{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 align=\"center\">Data Augmentation for Deep Learning</h1>\n",
    "\n",
    "\n",
    "**Summary:**\n",
    "1. SimpleITK supports a variety of spatial transformations (global or local) that can be used to augment your dataset via resampling directly from the original images (which vary in size).\n",
    "2. Resampling to a uniform size can be done either by specifying the desired sizes, which most often results in non-isotropic pixel spacings, or by specifying an isotropic pixel spacing and one of the image sizes (width, height, depth). \n",
    "3. SimpleITK supports a variety of intensity transformations (blurring, adding noise etc.) that can be used to augment your dataset after it has been resampled to the size expected by your network.\n",
    "\n",
    "This notebook illustrates the use of SimpleITK to perform data augmentation for deep learning. Note that the code is written so that the relevant functions work for both 2D and 3D images without modification.\n",
    "\n",
    "Data augmentation is a model-based approach for enlarging your training set. The problem being addressed is that the original dataset is not sufficiently representative of the general population of images. As a consequence, if we train only on the original dataset, the resulting network will not generalize well to the population (overfitting). \n",
    "\n",
    "Using a model of the variations found in the general population together with the existing dataset, we generate additional images in the hope of capturing the population variability. Note that if the model you use is incorrect you can cause harm: you are generating observations that do not occur in the general population and are optimizing a function to fit them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "library(SimpleITK)\n",
    "\n",
    "source(\"downloaddata.R\")\n",
    "\n",
    "OUTPUT_DIR <- 'Output'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Before we start, a word of caution\n",
    "\n",
    "**Whenever you sample there is potential for aliasing (Nyquist theorem).**\n",
    "\n",
    "In many cases, data prepared for use with a deep learning network is resampled to a fixed size. When we perform data augmentation via spatial transformations we also perform resampling. \n",
    "\n",
    "Admittedly, the example below is exaggerated to illustrate the point, but it serves as a reminder that you may want to consider smoothing your images prior to resampling. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_allowed": "Exception in SITK"
    },
   "outputs": [],
   "source": [
    "# The image we will resample (a grid).\n",
    "grid_image <- GridSource(outputPixelType=\"sitkUInt16\", size=c(512,512), \n",
    "                             sigma=c(0.1,0.1), gridSpacing=c(20.0,20.0))\n",
    "Show(grid_image, \"original grid image\")\n",
    "\n",
    "# The spatial definition of the images we want to use in a deep learning framework (smaller than the original). \n",
    "new_size <- c(100, 100)\n",
    "reference_image <- Image(new_size, grid_image$GetPixelID())\n",
    "reference_image$SetOrigin(grid_image$GetOrigin())\n",
    "reference_image$SetDirection(grid_image$GetDirection())\n",
    "reference_image$SetSpacing(grid_image$GetSize()*grid_image$GetSpacing()/new_size)\n",
    "\n",
    "# Resample without any smoothing.\n",
    "Show(Resample(grid_image, reference_image) , \"resampled without smoothing\")\n",
    "\n",
    "# Resample after Gaussian smoothing.\n",
    "Show(Resample(SmoothingRecursiveGaussian(grid_image, 2.0), reference_image), \"resampled with smoothing\")"
   ]
  },
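  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same effect can be reproduced in one dimension without SimpleITK. The following base R sketch is an illustration only: the signal, the decimation factor and the width-5 moving average (a crude stand-in for Gaussian smoothing) are assumptions chosen for this example, not values used elsewhere in the notebook.\n",
    "\n",
    "```r\n",
    "# A 1D 'grid': a one-sample-wide line every 20 samples.\n",
    "n <- 500\n",
    "signal <- numeric(n)\n",
    "signal[seq(1, n, by = 20)] <- 1\n",
    "\n",
    "# Naive decimation by a factor of 5: keep samples 3, 8, 13, ...\n",
    "keep <- seq(3, n, by = 5)\n",
    "naive <- signal[keep]\n",
    "\n",
    "# Low-pass first using a centered moving average of width 5, then decimate.\n",
    "smoothed <- as.numeric(stats::filter(signal, rep(1/5, 5), sides = 2))\n",
    "smoothed[is.na(smoothed)] <- 0  # the filter is undefined at the two samples next to each boundary\n",
    "filtered <- smoothed[keep]\n",
    "\n",
    "# Number of retained samples in which a grid line is still visible.\n",
    "print(c(without_smoothing = sum(naive > 0), with_smoothing = sum(filtered > 0)))\n",
    "```\n",
    "\n",
    "Without smoothing every line falls between the retained samples and the grid disappears. With smoothing each line contributes to a retained sample, at reduced contrast."
   ]
  },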
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data\n",
    "\n",
    "Load the images. You can work through the notebook using either the original 3D images or 2D slices from the original volumes. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "data <- list(ReadImage(fetch_data(\"nac-hncma-atlas2013-Slicer4Version/Data/A1_grayT1.nrrd\"), \"sitkFloat32\"),\n",
    "             ReadImage(fetch_data(\"vm_head_mri.mha\"), \"sitkFloat32\"))\n",
    "\n",
    "# Comment out the following line if you want to work in 3D. Note that in 3D some of the notebook visualizations are \n",
    "# disabled. \n",
    "data <- list(data[[1]][,161,], data[[2]][,,18])\n",
    "\n",
    "invisible(lapply(data, Show))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The original data often needs to be modified. In this example we would like to crop the images so that we only keep the informative regions. We can readily separate the foreground and background using an appropriate threshold; in our case we use Otsu's threshold selection method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "# Use Otsu's threshold estimator to separate background and foreground. In medical imaging the background is\n",
    "# usually air. Then crop the image using the foreground's axis aligned bounding box.\n",
    "# Args:\n",
    "#   image (SimpleITK image): An image where the anatomy and background intensities form a bi-modal distribution\n",
    "#                           (the assumption underlying Otsu's method.)\n",
    "# Return:\n",
    "#   Cropped image based on foreground's axis aligned bounding box.  \n",
    "threshold_based_crop <- function(image) {\n",
    "  # Set pixels that are in [min_intensity,otsu_threshold] to inside_value, values above otsu_threshold are\n",
    "  # set to outside_value. The anatomy has higher intensity values than the background, so it is outside.\n",
    "  inside_value <- 0\n",
    "  outside_value <- 255\n",
    "  label_shape_filter <- LabelShapeStatisticsImageFilter()\n",
    "  label_shape_filter$Execute( OtsuThreshold(image, inside_value, outside_value) )\n",
    "  bounding_box <- label_shape_filter$GetBoundingBox(outside_value)\n",
    "  # The bounding box's first \"dim\" entries are the starting index and last \"dim\" entries the size\n",
    "  vec_len <- length(bounding_box)\n",
    "  return(RegionOfInterest(image, bounding_box[(vec_len/2 + 1) : vec_len], bounding_box[1:(vec_len/2)]))\n",
    "}\n",
    "    \n",
    "modified_data <- lapply(data, threshold_based_crop)\n",
    "\n",
    "invisible(lapply(modified_data, Show))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point we select the images we want to work with. Skip the following cell if you want to work with the original data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data = modified_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Augmentation using spatial transformations\n",
    "\n",
    "We next illustrate the generation of images by specifying a list of transformation parameter values representing a sampling of the transformation's parameter space.\n",
    "\n",
    "The code below is agnostic to the specific transformation and it is up to the user to specify a valid list of transformation parameters (correct number of parameters and correct order). \n",
    "\n",
    "In most cases we can easily specify a regular grid in parameter space by specifying ranges of values for each of the parameters. In some cases specifying parameter values may be less intuitive (e.g. the versor representation of rotation)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility methods\n",
    "\n",
    "Utilities for sampling a parameter space using a regular grid in a convenient manner (special care for 3D similarity).  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Create a list representing a regular sampling of the parameter space.\n",
    "# Args:\n",
    "#   ...: two or more vectors representing parameter values. The order\n",
    "#        of the vectors should match the ordering of the SimpleITK transformation\n",
    "#        parameterization (e.g. Similarity2DTransform: scaling, rotation, tx, ty)\n",
    "# Return:\n",
    "#   List of vectors representing the regular grid sampling.\n",
    "parameter_space_regular_grid_sampling <- function(...) {\n",
    "    df <- expand.grid(list(...))\n",
    "    return(lapply(split(df,seq_along(df[,1])), as.vector))\n",
    "}\n",
    "\n",
    "\n",
    "# Create a list representing a regular sampling of the 3D similarity transformation parameter space. As the\n",
    "# SimpleITK rotation parameterization uses the vector portion of a versor we don't have an\n",
    "# intuitive way of specifying rotations. We therefore use the ZYX Euler angle parameterization and convert to\n",
    "# versor.\n",
    "# Args:\n",
    "#   thetaX, thetaY, thetaZ: vectors with the Euler angle values to use.\n",
    "#   tx, ty, tz: vectors with the translation values to use.\n",
    "#   scale: vector with the scale values to use.\n",
    "# Return:\n",
    "#   List of vectors representing the parameter space sampling (vx,vy,vz,tx,ty,tz,s).\n",
    "similarity3D_parameter_space_regular_sampling <- function(thetaX, thetaY, thetaZ, tx, ty, tz, scale) {\n",
    "  euler_sampling <- parameter_space_regular_grid_sampling(thetaX, thetaY, thetaZ, tx, ty, tz, scale)\n",
    "  # replace Euler angles with quaternion vector\n",
    "  for(i in seq_along(euler_sampling)) {\n",
    "    euler_sampling[[i]][1:3] <- eul2quat(as.double(euler_sampling[[i]][1]),\n",
    "                                         as.double(euler_sampling[[i]][2]),\n",
    "                                         as.double(euler_sampling[[i]][3]))\n",
    "  }\n",
    "  return(euler_sampling)\n",
    "}\n",
    "\n",
    "# Translate between Euler angle (ZYX) order and quaternion representation of a rotation.\n",
    "# Args:\n",
    "#   ax: X rotation angle in radians.\n",
    "#   ay: Y rotation angle in radians.\n",
    "#   az: Z rotation angle in radians.\n",
    "#   atol: tolerance used for stable quaternion computation (qs==0 within this tolerance).\n",
    "# Return:\n",
    "#      Vector with three entries representing the vectorial component of the quaternion.\n",
    "eul2quat <- function(ax, ay, az, atol=1e-8) {\n",
    "    # Create rotation matrix using ZYX Euler angles and then compute quaternion using entries.\n",
    "    cx <- cos(ax)\n",
    "    cy <- cos(ay)\n",
    "    cz <- cos(az)\n",
    "    sx <- sin(ax)\n",
    "    sy <- sin(ay)\n",
    "    sz <- sin(az)\n",
    "    r <- array(0,c(3,3))\n",
    "    r[1,1] <- cz*cy\n",
    "    r[1,2] <- cz*sy*sx - sz*cx\n",
    "    r[1,3] <- cz*sy*cx+sz*sx\n",
    "\n",
    "    r[2,1] <- sz*cy\n",
    "    r[2,2] <- sz*sy*sx + cz*cx\n",
    "    r[2,3] <- sz*sy*cx - cz*sx\n",
    "\n",
    "    r[3,1] <- -sy\n",
    "    r[3,2] <- cy*sx\n",
    "    r[3,3] <- cy*cx\n",
    "\n",
    "    # Compute quaternion:\n",
    "    qs <- 0.5*sqrt(r[1,1] + r[2,2] + r[3,3] + 1)\n",
    "    qv <- c(0,0,0)\n",
    "    # If the scalar component of the quaternion is close to zero, we\n",
    "    # compute the vector part using a numerically stable approach\n",
    "    if(abs(qs - 0.0) < atol) {\n",
    "        i <- which.max(c(r[1,1], r[2,2], r[3,3]))\n",
    "        j <- (i+1)%%3 + 1\n",
    "        k <- (j+1)%%3 + 1\n",
    "        w <- sqrt(r[i,i] - r[j,j] - r[k,k] + 1)\n",
    "        qv[i] <- 0.5*w\n",
    "        qv[j] <- (r[i,j] + r[j,i])/(2*w)\n",
    "        qv[k] <- (r[i,k] + r[k,i])/(2*w)\n",
    "    } else {\n",
    "        denom <- 4*qs\n",
    "        qv[1] <- (r[3,2] - r[2,3])/denom;\n",
    "        qv[2] <- (r[1,3] - r[3,1])/denom;\n",
    "        qv[3] <- (r[2,1] - r[1,2])/denom;\n",
    "    }\n",
    "    return(qv)\n",
    "}"
   ]
  },
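  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cross-check on the matrix-based conversion, the vector part of the quaternion for a ZYX Euler rotation (R = Rz Ry Rx) can also be written in closed form using half angles. The sketch below is an alternative added for illustration, not part of the notebook's pipeline, and the function name eul2quat_closed_form is hypothetical. It does not normalize the sign of the scalar part, so it should be expected to agree with the matrix-based result only while the scalar part is positive, which is the case for small angles such as the ones sampled in this notebook.\n",
    "\n",
    "```r\n",
    "# Hypothetical helper: vector part of the unit quaternion for R = Rz(az) Ry(ay) Rx(ax).\n",
    "eul2quat_closed_form <- function(ax, ay, az) {\n",
    "  cx <- cos(ax/2); sx <- sin(ax/2)\n",
    "  cy <- cos(ay/2); sy <- sin(ay/2)\n",
    "  cz <- cos(az/2); sz <- sin(az/2)\n",
    "  c(sx*cy*cz - cx*sy*sz,\n",
    "    cx*sy*cz + sx*cy*sz,\n",
    "    cx*cy*sz - sx*sy*cz)\n",
    "}\n",
    "\n",
    "# A rotation about a single axis by angle theta has vector part sin(theta/2) along that axis.\n",
    "theta <- pi/18\n",
    "print(eul2quat_closed_form(0, 0, theta))\n",
    "print(c(0, 0, sin(theta/2)))\n",
    "```\n",
    "\n",
    "The two printed vectors should match, which gives a quick way to sanity check any Euler angle to versor conversion before generating data with it."
   ]
  },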
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create reference domain \n",
    "\n",
    "All input images will be resampled onto the reference domain.\n",
    "\n",
    "This domain is defined by two constraints: the number of pixels per dimension and the physical size we want the reference domain to occupy. The former is associated with the computational constraints of deep learning, where using a small number of pixels is desired. The latter is associated with the SimpleITK concept of an image: it occupies a region in physical space, which should be large enough to encompass the object of interest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dimension <- data[[1]]$GetDimension()\n",
    "\n",
    "# Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.\n",
    "reference_physical_size <- numeric(dimension)\n",
    "physical_sizes <- lapply(data, function(image){return ((image$GetSize()-1)*image$GetSpacing())})\n",
    "reference_physical_size <- apply(do.call(rbind,physical_sizes),2,max)\n",
    "\n",
    "# Create the reference image with a zero origin, identity direction cosine matrix and dimension     \n",
    "reference_origin <- numeric(dimension)\n",
    "reference_direction <- as.vector(t(diag(dimension)))\n",
    "\n",
    "# Select the number of pixels per dimension. This is either arbitrary (the smallest size that yields the\n",
    "# desired results) or dictated by a pretrained network used for transfer learning (e.g. VGG-16, 224x224).\n",
    "# This will often result in non-isotropic pixel spacing.\n",
    "reference_size <- rep(128, dimension)\n",
    "reference_spacing <- reference_physical_size / (reference_size-1)\n",
    "\n",
    "# Another possibility is that you want isotropic pixels, then you can specify the image size for one of\n",
    "# the axes and the others are determined by this choice. Below we choose to set the x axis to 128 and the\n",
    "# spacing set accordingly. \n",
    "# Uncomment the following lines to use this strategy.\n",
    "#reference_size_x <- 128\n",
    "#reference_spacing <- rep(reference_physical_size[1]/(reference_size_x-1),dimension)\n",
    "#reference_size <- as.integer(reference_physical_size / reference_spacing + 1)\n",
    "\n",
    "reference_image <- Image(reference_size, data[[1]]$GetPixelID())\n",
    "reference_image$SetOrigin(reference_origin)\n",
    "reference_image$SetSpacing(reference_spacing)\n",
    "reference_image$SetDirection(reference_direction)\n",
    "\n",
    "# Always use TransformContinuousIndexToPhysicalPoint to compute an indexed point's physical coordinates as \n",
    "# this takes into account the origin, spacing and direction cosines. For the vast majority of images the direction \n",
    "# cosines are the identity matrix, but when this isn't the case simply multiplying the central index by the \n",
    "# spacing will not yield the correct coordinates, resulting in a long debugging session. \n",
    "reference_center <- reference_image$TransformContinuousIndexToPhysicalPoint(reference_image$GetSize()/2.0)"
   ]
  },
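  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two sizing strategies can be checked with plain arithmetic. The sketch below uses a hypothetical physical extent of 254 x 190.5 (in the images' physical units), chosen so that the divisions are exact. It mirrors the formulas in the cell above without calling SimpleITK.\n",
    "\n",
    "```r\n",
    "# Hypothetical physical extent of the reference domain, measured between the\n",
    "# centers of the first and last pixels (hence the 'size - 1' below).\n",
    "reference_physical_size <- c(254, 190.5)\n",
    "\n",
    "# Strategy 1: fix the number of pixels and accept non-isotropic spacing.\n",
    "reference_size <- c(128, 128)\n",
    "non_isotropic_spacing <- reference_physical_size / (reference_size - 1)\n",
    "print(non_isotropic_spacing)  # 2.0 1.5\n",
    "\n",
    "# Strategy 2: fix the size along x, use isotropic spacing and derive the other size.\n",
    "reference_size_x <- 128\n",
    "isotropic_spacing <- rep(reference_physical_size[1] / (reference_size_x - 1), 2)\n",
    "isotropic_size <- as.integer(reference_physical_size / isotropic_spacing + 1)\n",
    "print(isotropic_size)  # 128 96\n",
    "```\n",
    "\n",
    "With other extents the quotient may not be an exact integer in floating point, so as.integer can truncate to one pixel fewer than intended. Rounding before the conversion is one way to guard against this."
   ]
  },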
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data generation\n",
    "\n",
    "Once we have a reference domain we can augment the data using any of the SimpleITK global domain transformations. In this notebook we use a similarity transformation (the augment_images_spatial function is agnostic to this specific choice).\n",
    "\n",
    "Note that you also need to create the labels for your augmented images. If these are just classes then your processing is minimal. If you are dealing with segmentation you will also need to transform the segmentation labels so that they match the transformed image. The following function easily accommodates this: just provide the labeled image as input and use the \"sitkNearestNeighbor\" interpolator so that you do not introduce labels that were not in the original segmentation.   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Generate the resampled images based on the given transformations.\n",
    "#  Args:\n",
    "#    original_image (SimpleITK image): The image which we will resample and transform.\n",
    "#    reference_image (SimpleITK image): The image onto which we will resample.\n",
    "#    T0 (SimpleITK transform): Transformation which maps points from the reference image coordinate system \n",
    "#            to the original_image coordinate system.\n",
    "#    T_aug (SimpleITK transform): Map points from the reference_image coordinate system back onto itself using the\n",
    "#           given transformation_parameters. The reason we use this transformation as a parameter\n",
    "#           is to allow the user to set its center of rotation to something other than zero.\n",
    "#    transformation_parameters (List of vectors): parameter values which we pass to T_aug$SetParameters().\n",
    "#    output_prefix (string): output file name prefix (file name: output_prefix_p1_p2_..pn_.output_suffix).\n",
    "#    output_suffix (string): output file name suffix (file name: output_prefix_p1_p2_..pn_.output_suffix).\n",
    "#    interpolator: One of the SimpleITK interpolators.\n",
    "#    default_intensity_value: The value to return if a point is mapped outside the original_image domain\n",
    "augment_images_spatial <- function(original_image, reference_image, T0, T_aug, transformation_parameters,\n",
    "                    output_prefix, output_suffix,\n",
    "                    interpolator = \"sitkLinear\", default_intensity_value = 0.0) {\n",
    "    all_images <- lapply(transformation_parameters, \n",
    "                         function(current_parameters) {\n",
    "                           T_aug$SetParameters(current_parameters)\n",
    "                           # Augmentation is done in the reference image space, so we first map the points from \n",
    "                           # the reference image space back onto itself T_aug (e.g. rotate the reference image)\n",
    "                           # and then we map to the original image space T0.\n",
    "                           T_all <- Transform(T0)\n",
    "                           T_all$AddTransform(T_aug)\n",
    "                           aug_image <- Resample(original_image, reference_image, T_all,\n",
    "                                                 interpolator, default_intensity_value)\n",
    "                           WriteImage(aug_image, paste0(output_prefix,\"_\",paste0(current_parameters, collapse=\"_\"),\n",
    "                                                        \"_.\",output_suffix))\n",
    "                           return(aug_image)\n",
    "                         })\n",
    "    return(all_images) # Used only for display purposes in this notebook.\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Before we can use the augment_images_spatial function we need to compute the transformation which maps points between the reference image and the current image, as shown in the code cell below. \n",
    "\n",
    "Note that it is very easy to generate large amounts of data: sampling $m$ parameters with $n$ values each (the calls to seq below) results in $n^m$ images per input image, and these images are also saved to disk. **If you run the code below for 3D data you will generate 4374 volumes ($3^7 = 2187$ parameter combinations for each of the 2 volumes).**  "
   ]
  },
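  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of generated images can be estimated before running the augmentation. This base R sketch uses expand.grid, the same function that parameter_space_regular_grid_sampling relies on. The column names are only labels added here for readability.\n",
    "\n",
    "```r\n",
    "# Three values per parameter for a 2D similarity transform (scale, angle, tx, ty).\n",
    "grid_2d <- expand.grid(scale = seq(0.9, 1.1, 0.1),\n",
    "                       angle = seq(-pi/18, pi/18, pi/18),\n",
    "                       tx = seq(-10, 10, 10),\n",
    "                       ty = seq(-10, 10, 10))\n",
    "print(nrow(grid_2d))  # 3^4 = 81 images per 2D input image\n",
    "\n",
    "# Seven parameters for the 3D similarity transform (three rotations, three translations, scale).\n",
    "print(3^7)            # 2187 volumes per 3D input image\n",
    "```\n",
    "\n",
    "Multiplying the per-image count by the number of input images and by the size of one image on disk gives a rough estimate of the storage required."
   ]
  },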
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "if(dimension == 2) {\n",
    "  # The parameters are scale (+-10%), rotation angle (+-10 degrees), x translation, y translation\n",
    "  transformation_parameters_list = parameter_space_regular_grid_sampling(seq(0.9,1.1,0.1),\n",
    "                                                                         seq(-pi/18.0,pi/18.0,pi/18.0),\n",
    "                                                                         seq(-10,10,10),\n",
    "                                                                         seq(-10,10,10))\n",
    "  aug_transform <- Similarity2DTransform()\n",
    "} else {   \n",
    "  transformation_parameters_list = similarity3D_parameter_space_regular_sampling(seq(-pi/18.0,pi/18.0,pi/18.0),\n",
    "                                                                                 seq(-pi/18.0,pi/18.0,pi/18.0),\n",
    "                                                                                 seq(-pi/18.0,pi/18.0,pi/18.0),\n",
    "                                                                                 seq(-10,10,10),\n",
    "                                                                                 seq(-10,10,10),\n",
    "                                                                                 seq(-10,10,10),\n",
    "                                                                                 seq(0.9,1.1,0.1))        \n",
    "  aug_transform <- Similarity3DTransform()\n",
    "}\n",
    "        \n",
    "all_images <- lapply(seq_along(data), \n",
    "                     function(i, images){\n",
    "                       img <- images[[i]]\n",
    "                       # Transform which maps from the reference_image to the current img with the translation mapping the image\n",
    "                       # origins to each other.\n",
    "                       transform <- AffineTransform(dimension)\n",
    "                       transform$SetMatrix(img$GetDirection())\n",
    "                       transform$SetTranslation(img$GetOrigin() - reference_origin)\n",
    "                       # Modify the transformation to align the centers of the original and reference image instead \n",
    "                       # of their origins.\n",
    "                       centering_transform <- TranslationTransform(dimension)\n",
    "                       img_center <- img$TransformContinuousIndexToPhysicalPoint(img$GetSize()/2.0)\n",
    "                       centering_transform$SetOffset(transform$GetInverse()$TransformPoint(img_center) - \n",
    "                                                     reference_center)\n",
    "                       centered_transform <- Transform(transform)\n",
    "                       centered_transform$AddTransform(centering_transform)\n",
    "\n",
    "                       # Set the augmenting transform's center so that rotation is around the image center.\n",
    "                       aug_transform$SetCenter(reference_center)\n",
    "                       return(augment_images_spatial(img, reference_image, centered_transform, \n",
    "                                                     aug_transform, transformation_parameters_list, \n",
    "                                                     file.path(OUTPUT_DIR, paste0('spatial_aug',i)), 'mha'))\n",
    "                     }, \n",
    "                     data)\n",
    "if(dimension==2)\n",
    "  # For each 2D image,stack the augmentation images into a 3D volume and display.\n",
    "  invisible(lapply(all_images, function(images) {Show(JoinSeries(images))}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## What about flipping\n",
    "\n",
    "Reflection using SimpleITK can be done in one of several ways:\n",
    "1. Use an affine transform with the matrix component set to a reflection matrix. The columns of the matrix correspond to the $\\mathbf{x}, \\mathbf{y}$ and $\\mathbf{z}$ axes. The reflection matrix is constructed by keeping the standard basis vectors, $\\mathbf{e}_i, \\mathbf{e}_j$, that span the plane (3D) or axis (2D) we want to reflect through, and setting the remaining basis vector to $-\\mathbf{e}_k$.  \n",
    "    * Reflection about $xy$ plane: $[\\mathbf{e}_1, \\mathbf{e}_2, -\\mathbf{e}_3]$.\n",
    "    * Reflection about $xz$ plane: $[\\mathbf{e}_1, -\\mathbf{e}_2, \\mathbf{e}_3]$.\n",
    "    * Reflection about $yz$ plane: $[-\\mathbf{e}_1, \\mathbf{e}_2, \\mathbf{e}_3]$.\n",
    "2. Use the native slicing operator (e.g. img[,seq(img$GetHeight(),1,-1),]), or the FlipImageFilter after the image is resampled onto the reference image grid. \n",
    "\n",
    "We prefer option 1 as it is computationally more efficient. It combines all transformations prior to resampling, while the other approach performs resampling onto the reference image grid followed by the reflection operation. An additional difference is that using slicing or the FlipImageFilter will also modify the image origin, while the resampling approach keeps the spatial location of the reference image origin intact. This minor difference is of no concern in deep learning as the content of the images is the same, but in SimpleITK two images are considered equivalent if and only if their content and spatial extent are the same.\n",
    "\n",
    "The following cell corresponds to the preferred option, using an affine transformation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "flipped_images <- lapply(data, \n",
    "       function(img) {\n",
    "         # Compute the transformation which maps between the reference and current image (same as done above).\n",
    "         transform <- AffineTransform(dimension)\n",
    "         transform$SetMatrix(img$GetDirection())\n",
    "         transform$SetTranslation(img$GetOrigin() - reference_origin)\n",
    "         centering_transform <- TranslationTransform(dimension)\n",
    "         img_center <- img$TransformContinuousIndexToPhysicalPoint(img$GetSize()/2.0)\n",
    "         centering_transform$SetOffset(transform$GetInverse()$TransformPoint(img_center) - reference_center)\n",
    "         centered_transform <- Transform(transform)\n",
    "         centered_transform$AddTransform(centering_transform)\n",
    "    \n",
    "         flipped_transform <- AffineTransform(dimension)    \n",
    "         flipped_transform$SetCenter(reference_image$TransformContinuousIndexToPhysicalPoint(reference_image$GetSize()/2.0))\n",
    "         if(dimension==2) { # matrices in SimpleITK specified in row major order\n",
    "           flipped_transform$SetMatrix(c(1,0,0,-1))\n",
    "         } else {\n",
    "           flipped_transform$SetMatrix(c(1,0,0,0,-1,0,0,0,1))\n",
    "         }\n",
    "         centered_transform$AddTransform(flipped_transform)    \n",
    "         # Resample onto the reference image \n",
    "         return(Resample(img, reference_image, centered_transform, \"sitkLinear\", 0.0))\n",
    "       }\n",
    "      )\n",
    "invisible(lapply(flipped_images,Show))"
   ]
  },
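  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the reflection matrix does to a single point, the following base R sketch applies p' = A(p - c) + c, the form used by an affine transform with a center. The center and point coordinates are hypothetical values chosen for illustration.\n",
    "\n",
    "```r\n",
    "# Reflection that negates the y coordinate, matching the 2D matrix (1,0,0,-1) used in the cell above.\n",
    "A <- matrix(c(1,  0,\n",
    "              0, -1), nrow = 2, byrow = TRUE)\n",
    "center <- c(64, 64)   # hypothetical center of the reference image\n",
    "p <- c(10, 20)        # hypothetical point\n",
    "\n",
    "p_reflected <- as.vector(A %*% (p - center) + center)\n",
    "print(p_reflected)    # 10 108\n",
    "```\n",
    "\n",
    "The x coordinate is unchanged and the y coordinate is mirrored about the line y = 64. Applying the same mapping a second time returns the original point."
   ]
  },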
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Radial Distortion\n",
    "\n",
    "Some 2D medical imaging modalities, such as endoscopic video and X-ray images acquired with C-arms using image intensifiers, exhibit radial distortion. The common model for such distortion was described by Brown [\"Close-range camera calibration\", Photogrammetric Engineering, 37(8):855–866, 1971]:\n",
    "$$\n",
    "\\mathbf{p}_u = \\mathbf{p}_d + (\\mathbf{p}_d-\\mathbf{p}_c)(k_1r^2 + k_2r^4 + k_3r^6 + \\ldots)\n",
    "$$\n",
    "\n",
    "where:\n",
    "* $\\mathbf{p}_u$ is a point in the undistorted image\n",
    "* $\\mathbf{p}_d$ is a point in the distorted image\n",
    "* $\\mathbf{p}_c$ is the center of distortion\n",
    "* $r = \\|\\mathbf{p}_d-\\mathbf{p}_c\\|$\n",
    "* $k_i$ are coefficients of the radial distortion\n",
    "\n",
    "\n",
    "Using SimpleITK operators we represent this transformation using a deformation field as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "radial_distort <- function(image, k1, k2, k3, distortion_center=NULL) {\n",
    "  c <- distortion_center\n",
    "  if(is.null(c)) { # The default distortion center coincides with the image center\n",
    "    c <- image$TransformContinuousIndexToPhysicalPoint(image$GetSize()/2.0)\n",
    "  }\n",
    "  # Compute the vector image (p_d - p_c) \n",
    "  delta_image <- PhysicalPointSource( \"sitkVectorFloat64\", image$GetSize(), image$GetOrigin(), image$GetSpacing(), image$GetDirection())\n",
    "  delta_image_list <- lapply(seq_along(c), function(i,center) {return(VectorIndexSelectionCast(delta_image,i-1) - center[i])},c)\n",
    "\n",
    "  # Compute the radial distortion expression  \n",
    "  r2_image <- NaryAdd(lapply(delta_image_list, function(img){return(img**2)}))\n",
    "  r4_image <- r2_image**2\n",
    "  r6_image <- r2_image*r4_image\n",
    "  disp_image <- k1*r2_image + k2*r4_image + k3*r6_image\n",
    "  displacement_image <- Compose(lapply(delta_image_list, function(img){return(img*disp_image)}))\n",
    "\n",
    "  displacement_field_transform <- DisplacementFieldTransform(displacement_image)\n",
    "  return(Resample(image, image, displacement_field_transform))\n",
    "}\n",
    "                                       \n",
    "k1 = 0.00001\n",
    "k2 = 0.0000000000001\n",
    "k3 = 0.0000000000001\n",
    "original_image <- data[[1]]\n",
    "distorted_image <- radial_distort(original_image, k1, k2, k3)\n",
    "# Use a grid image to highlight the distortion.\n",
    "grid_image <- GridSource(outputPixelType=original_image$GetPixelID(), size=original_image$GetSize(), \n",
    "                         sigma=rep(0.1, dimension), gridSpacing=rep(20.0,dimension))\n",
    "grid_image$CopyInformation(original_image)\n",
    "distorted_grid <- radial_distort(grid_image, k1, k2, k3)\n",
    "Show(Tile(c(original_image, distorted_image, distorted_grid)))"
   ]
  },
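  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brown's model can also be evaluated for a single point with a few lines of base R. This is a sketch for intuition only: the point and the center of distortion are hypothetical, and the coefficients are the ones used in the cell above.\n",
    "\n",
    "```r\n",
    "k <- c(1e-5, 1e-13, 1e-13)   # k1, k2, k3\n",
    "p_c <- c(64, 64)             # hypothetical center of distortion\n",
    "p_d <- c(100, 50)            # hypothetical point in the distorted image\n",
    "\n",
    "delta <- p_d - p_c           # p_d - p_c\n",
    "r2 <- sum(delta^2)           # r^2\n",
    "radial_factor <- k[1]*r2 + k[2]*r2^2 + k[3]*r2^3\n",
    "p_u <- p_d + delta * radial_factor\n",
    "print(p_u)\n",
    "```\n",
    "\n",
    "The displacement is directed along the line through the center of distortion and grows with the distance from it, so points near the image border move the most."
   ]
  },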
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transferring deformations - exercise for the interested reader\n",
    "\n",
    "Using SimpleITK we can readily transfer deformations from a spatio-temporal data set to another spatial data set to simulate temporal behavior. Case in point, using a 4D (3D+time) CT of the thorax we can estimate the respiratory motion using non-rigid registration and Free Form Deformation or displacement field transformations. We can then register a new spatial data set to the original spatial CT (non-rigidly) followed by application of the temporal deformations.\n",
    "\n",
    "Note that unlike the arbitrary spatial transformations we used for data-augmentation above this approach is more computationally expensive as it involves multiple non-rigid registrations. Also note that as the goal is to use the estimated transformations to create plausible deformations you may be able to relax the required registration accuracy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Augmentation using intensity modifications\n",
    "\n",
    "SimpleITK has many filters that are potentially relevant for data augmentation via modification of intensities. For example:\n",
    "* Image smoothing. Always read the documentation carefully, as similar filters use different parameterizations, $\\sigma$ vs. variance ($\\sigma^2$):\n",
    "  * [Discrete Gaussian](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1DiscreteGaussianImageFilter.html)\n",
    "  * [Recursive Gaussian](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1RecursiveGaussianImageFilter.html)\n",
    "  * [Smoothing Recursive Gaussian](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1SmoothingRecursiveGaussianImageFilter.html)\n",
    "\n",
    "* Edge preserving image smoothing:\n",
    "  * [Bilateral image filtering](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1BilateralImageFilter.html)\n",
    "  * [Median filtering](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1MedianImageFilter.html)\n",
    "\n",
    "* Adding noise to your images:\n",
    "  * [Additive Gaussian](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1AdditiveGaussianNoiseImageFilter.html)\n",
    "  * [Salt and Pepper / Impulse](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1SaltAndPepperNoiseImageFilter.html)\n",
    "  * [Shot/Poisson](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ShotNoiseImageFilter.html)\n",
    "  * [Speckle/multiplicative](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1SpeckleNoiseImageFilter.html)\n",
    "  \n",
    "* [Adaptive Histogram Equalization](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1AdaptiveHistogramEqualizationImageFilter.html)"
   ]
  },
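  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The $\\sigma$ vs. variance distinction noted above is easy to get wrong, so here is a minimal sketch in base R (no SimpleITK calls) that samples a 1D Gaussian kernel from a given $\\sigma$ and checks that its second moment is close to $\\sigma^2$. The helper `gaussian_kernel` and the truncation radius of $4\\sigma$ are assumptions made for this illustration; they are not part of SimpleITK.\n",
    "\n",
    "```r\n",
    "# Hypothetical helper: sample a normalized 1D Gaussian kernel with standard\n",
    "# deviation sigma, truncated at radius samples on each side of the center.\n",
    "gaussian_kernel <- function(sigma, radius = ceiling(4 * sigma)) {\n",
    "  x <- -radius:radius\n",
    "  w <- exp(-x^2 / (2 * sigma^2))\n",
    "  list(x = x, w = w / sum(w))\n",
    "}\n",
    "\n",
    "sigma <- 2.0\n",
    "kernel <- gaussian_kernel(sigma)\n",
    "# The kernel is symmetric about zero, so its variance is its second moment.\n",
    "kernel_variance <- sum(kernel$w * kernel$x^2)\n",
    "print(c(sigma = sigma, variance = kernel_variance))\n",
    "```\n",
    "\n",
    "With $\\sigma=2$ the kernel variance is approximately 4, which is why `SetSigma(2.0)` on a filter parameterized by $\\sigma$ and `SetVariance(4.0)` on a filter parameterized by variance describe a comparable amount of smoothing."
   ]
  },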
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Generate intensity modified images from the originals.\n",
    "# Args:\n",
    "#   image_list (containing SimpleITK images): The images whose intensities we modify.\n",
    "#   output_prefix (string): output file name prefix.\n",
    "#   output_suffix (string): output file name suffix (file extension, without the dot).\n",
    "# Each output file is named <output_prefix><i>_<FilterName>.<output_suffix>, where i is the\n",
    "# index of the image in image_list.\n",
    "# Returns:\n",
    "#   A list with one entry per input image, each a list of the filtered images.\n",
    "augment_images_intensity <- function(image_list, output_prefix, output_suffix) {\n",
    "\n",
    "  # Create a list of intensity modifying filters, which we apply to the given images\n",
    "  num_filters <- 10\n",
    "  index <- 1\n",
    "  filter_list <- vector(\"list\",num_filters)\n",
    "    \n",
    "  # Smoothing filters\n",
    "  filter_list[[index]] <- SmoothingRecursiveGaussianImageFilter()\n",
    "  filter_list[[index]]$SetSigma(2.0)\n",
    "  index = index + 1\n",
    "    \n",
    "  filter_list[[index]] <- DiscreteGaussianImageFilter()\n",
    "  filter_list[[index]]$SetVariance(4.0)\n",
    "  index = index + 1\n",
    "    \n",
    "  filter_list[[index]] <- BilateralImageFilter()\n",
    "  filter_list[[index]]$SetDomainSigma(4.0)\n",
    "  filter_list[[index]]$SetRangeSigma(8.0)\n",
    "  index = index + 1\n",
    "  \n",
    "  filter_list[[index]] <- MedianImageFilter()\n",
    "  filter_list[[index]]$SetRadius(8)\n",
    "  index = index + 1\n",
    "    \n",
    "  # Noise filters using default settings\n",
    "    \n",
    "  # Filter control via SetMean, SetStandardDeviation.\n",
    "  filter_list[[index]] <- AdditiveGaussianNoiseImageFilter()\n",
    "  index = index + 1\n",
    "    \n",
    "  # Filter control via SetProbability\n",
    "  filter_list[[index]] <- SaltAndPepperNoiseImageFilter()\n",
    "  index = index + 1\n",
    "    \n",
    "  # Filter control via SetScale\n",
    "  filter_list[[index]] <- ShotNoiseImageFilter()\n",
    "  index = index + 1\n",
    "    \n",
    "  # Filter control via SetStandardDeviation\n",
    "  filter_list[[index]] <- SpeckleNoiseImageFilter()\n",
    "  index = index + 1\n",
    "    \n",
    "  filter_list[[index]] <- AdaptiveHistogramEqualizationImageFilter()\n",
    "  filter_list[[index]]$SetAlpha(1.0)\n",
    "  filter_list[[index]]$SetBeta(0.0)\n",
    "  index = index + 1\n",
    "    \n",
    "  filter_list[[index]] <- AdaptiveHistogramEqualizationImageFilter()\n",
    "  filter_list[[index]]$SetAlpha(0.0)\n",
    "  filter_list[[index]]$SetBeta(1.0)\n",
    "  index = index + 1\n",
    "  \n",
    "  # Apply every filter to every image and write each result to disk. The returned\n",
    "  # list of augmented images is used only for display purposes in this notebook.\n",
    "  aug_image_lists <- lapply(seq_along(image_list), function(i, images) {\n",
    "                       lapply(filter_list, function(filter) {\n",
    "                         aug_image <- filter$Execute(images[[i]])\n",
    "                         WriteImage(aug_image, paste0(output_prefix, i, '_',\n",
    "                                    filter$GetName(), '.', output_suffix))\n",
    "                         return(aug_image)\n",
    "                       })\n",
    "                      },\n",
    "                      image_list)\n",
    "  return(aug_image_lists)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Modify the intensities of the original images using the set of SimpleITK filters described above. If we are working with 2D images, the results are displayed inline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "intensity_augmented_images <- augment_images_intensity(data, file.path(OUTPUT_DIR, 'intensity_aug'), 'mha')\n",
    "\n",
    "# In 2D, we join the augmented versions of each image into a 3D volume, which we use for display.\n",
    "if(dimension==2) {\n",
    "    all_volumes <- lapply(intensity_augmented_images, \n",
    "                          function(image_list){\n",
    "                            JoinSeries(lapply(image_list, function(img){Cast(img, \"sitkFloat32\")}))      \n",
    "                          }\n",
    "                         )\n",
    "    invisible(lapply(all_volumes, Show))\n",
    "}"
   ]
  },
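  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The noise filters in `augment_images_intensity` run with their default settings. A minimal sketch of tuning two of them through the setters named in that function's comments (`SetMean`, `SetStandardDeviation`, `SetProbability`) is given here. The synthetic test image, the parameter values and the fixed seed are assumptions chosen only for illustration.\n",
    "\n",
    "```r\n",
    "library(SimpleITK)\n",
    "\n",
    "# Synthetic 2D Gaussian blob, used only to keep the example self-contained.\n",
    "image <- GaussianSource(outputPixelType='sitkUInt8', size=c(64, 64),\n",
    "                        sigma=c(16.0, 16.0), mean=c(32.0, 32.0), scale=255)\n",
    "\n",
    "# Zero-mean additive Gaussian noise with an illustrative standard deviation.\n",
    "gaussian_noise <- AdditiveGaussianNoiseImageFilter()\n",
    "gaussian_noise$SetMean(0.0)\n",
    "gaussian_noise$SetStandardDeviation(10.0)\n",
    "gaussian_noise$SetSeed(42)\n",
    "\n",
    "# Salt and pepper noise, with an illustrative 5% probability setting.\n",
    "salt_and_pepper <- SaltAndPepperNoiseImageFilter()\n",
    "salt_and_pepper$SetProbability(0.05)\n",
    "salt_and_pepper$SetSeed(42)\n",
    "\n",
    "noisy_images <- lapply(list(gaussian_noise, salt_and_pepper),\n",
    "                       function(filter) { filter$Execute(image) })\n",
    "```\n",
    "\n",
    "Setting the seed explicitly should make the generated noise repeatable across runs, which helps when debugging an augmentation pipeline. Use a different seed per image when generating the actual training set."
   ]
  },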
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, you can easily create intensity variations that are specific to your domain, such as the spatially varying multiplicative and additive transformation shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "simpleitk_error_allowed": "Exception in SITK"
   },
   "outputs": [],
   "source": [
    "# Modify the intensities using multiplicative and additive Gaussian bias fields.\n",
    "# Returns g_mult*original_image + g_add.\n",
    "mult_and_add_intensity_fields <- function(original_image) {\n",
    "  # Gaussian image with the same meta-information as the original (size, spacing, direction cosines).\n",
    "  # Sigma is half the image's physical size, the mean is the center of the image and the scale is 25.\n",
    "  g_mult = GaussianSource(original_image$GetPixelID(),\n",
    "                          original_image$GetSize(),\n",
    "                          (original_image$GetSize() - 1)*original_image$GetSpacing()/2.0,\n",
    "                          original_image$TransformContinuousIndexToPhysicalPoint(original_image$GetSize()/2.0),\n",
    "                          25,\n",
    "                          original_image$GetOrigin(),\n",
    "                          original_image$GetSpacing(),\n",
    "                          original_image$GetDirection())\n",
    "\n",
    "  # Gaussian image with the same meta-information as the original (size, spacing, direction cosines).\n",
    "  # Sigma is 1/8 of the image's physical size, the mean is at 1/16 of the image size and the scale is 25.\n",
    "  g_add = GaussianSource(original_image$GetPixelID(),\n",
    "                         original_image$GetSize(),\n",
    "                         (original_image$GetSize() - 1)*original_image$GetSpacing()/8.0,\n",
    "                         original_image$TransformContinuousIndexToPhysicalPoint(original_image$GetSize()/16.0),\n",
    "                         25,\n",
    "                         original_image$GetOrigin(),\n",
    "                         original_image$GetSpacing(),\n",
    "                         original_image$GetDirection())\n",
    "    \n",
    "  return(g_mult*original_image+g_add)\n",
    "}\n",
    "invisible(lapply(data, function(img){Show(mult_and_add_intensity_fields(img))}))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "R",
   "language": "R",
   "name": "ir"
  },
  "language_info": {
   "codemirror_mode": "r",
   "file_extension": ".r",
   "mimetype": "text/x-r-source",
   "name": "R",
   "pygments_lexer": "r",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
